# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math
import re
import textwrap

import torch

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.util import broadcast_shape, sum_rightmost
from pyro.ops.special import log_binomial

from .. import settings
from . import constraints


def _clamp_by_zero(x):
    # Works like clamp(x, min=0) but has gradient 0.5 at x = 0.
    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
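
# Both clamps use the same boundary convention in their backward pass, so at
# x == 0 the three terms above contribute gradients (1 + 1 - 1) / 2 = 0.5.
# A hypothetical sanity check:
#
#     x = torch.tensor(0.0, requires_grad=True)
#     _clamp_by_zero(x).backward()
#     assert x.grad.item() == 0.5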


class Beta(torch.distributions.Beta, TorchDistributionMixin):
    def conjugate_update(self, other):
        """
        EXPERIMENTAL.
        """
        assert isinstance(other, Beta)
        concentration1 = self.concentration1 + other.concentration1 - 1
        concentration0 = self.concentration0 + other.concentration0 - 1
        updated = Beta(concentration1, concentration0)

        def _log_normalizer(d):
            x = d.concentration1
            y = d.concentration0
            return (x + y).lgamma() - x.lgamma() - y.lgamma()

        log_normalizer = (
            _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated)
        )
        return updated, log_normalizer


class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
    # EXPERIMENTAL threshold on total_count above which sampling will use a
    # clamped Poisson approximation for Binomial samples. This is useful for
    # sampling very large populations.
    approx_sample_thresh = math.inf

    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
    # a shifted Stirling's approximation to the Beta function, reducing
    # computational cost from 3 lgamma() evaluations to 4 log() evaluations
    # plus arithmetic. Recommended values are between 0.1 and 0.01.
    approx_log_prob_tol = 0.0

    def sample(self, sample_shape=torch.Size()):
        if self.approx_sample_thresh < math.inf:
            exact = self.total_count <= self.approx_sample_thresh
            if not exact.all():
                # Approximate large counts with a moment-matched clamped Poisson.
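                # A Poisson's mean equals its variance, so sampling
                # Poisson(variance) and shifting by round(mean - variance)
                # matches the Binomial's first two moments; using min(p, 1 - p)
                # keeps counts in the low-probability regime where the Poisson
                # approximation is accurate.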
                with torch.no_grad():
                    shape = self._extended_shape(sample_shape)
                    p = self.probs
                    q = 1 - self.probs
                    mean = torch.min(p, q) * self.total_count
                    variance = p * q * self.total_count
                    shift = (mean - variance).round()
                    result = torch.poisson(variance.expand(shape))
                    result = torch.min(result + shift, self.total_count)
                    sample = torch.where(p < q, result, self.total_count - result)
                # Draw exact samples for remaining items.
                if exact.any():
                    total_count = torch.where(
                        exact, self.total_count, torch.zeros_like(self.total_count)
                    )
                    exact_sample = torch.distributions.Binomial(
                        total_count, self.probs, validate_args=False
                    ).sample(sample_shape)
                    sample = torch.where(exact, exact_sample, sample)
                return sample
        return super().sample(sample_shape)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        n = self.total_count
        k = value
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = n * (
            _clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p()
        )
        return (
            k * self.logits
            - normalize_term
            + log_binomial(n, k, tol=self.approx_log_prob_tol)
        )


@settings.register(
    "binomial_approx_sample_thresh", __name__, "Binomial.approx_sample_thresh"
)
def _validate_thresh(thresh):
    assert isinstance(thresh, float)
    assert 0 < thresh


@settings.register(
    "binomial_approx_log_prob_tol", __name__, "Binomial.approx_log_prob_tol"
)
def _validate_tol(tol):
    assert isinstance(tol, float)
    assert 0 <= tol
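
# For example, to enable approximate sampling for very large populations
# (a sketch; the threshold value is illustrative):
#
#     from pyro import settings
#     settings.set(binomial_approx_sample_thresh=10000.0)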


# This overloads .log_prob() and .enumerate_support() to speed up evaluating
# log_prob on the support of this variable: we can completely avoid tensor ops
# and merely reshape the self.logits tensor. This is especially important for
# Pyro models that use enumeration.
class Categorical(torch.distributions.Categorical, TorchDistributionMixin):
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}

    def log_prob(self, value):
        if getattr(value, "_pyro_categorical_support", None) == id(self):
            # Assume value is a reshaped torch.arange(event_shape[0]).
            # In this case we can call .reshape() rather than torch.gather().
            if not torch._C._get_tracing_state():
                if self._validate_args:
                    self._validate_sample(value)
                assert value.size(0) == self.logits.size(-1)
            logits = self.logits
            if logits.dim() <= value.dim():
                logits = logits.reshape(
                    (1,) * (1 + value.dim() - logits.dim()) + logits.shape
                )
            if not torch._C._get_tracing_state():
                assert logits.size(-1 - value.dim()) == 1
            return logits.transpose(-1 - value.dim(), -1).squeeze(-1)
        return super().log_prob(value)

    def enumerate_support(self, expand=True):
        result = super().enumerate_support(expand=expand)
        if not expand:
            result._pyro_categorical_support = id(self)
        return result
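
# A sketch of the fast path this enables (shapes are illustrative):
#
#     d = Categorical(logits=torch.randn(4, 3))
#     support = d.enumerate_support(expand=False)  # shape (3, 1)
#     d.log_prob(support)  # shape (3, 4), computed by reshaping d.logits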


class Dirichlet(torch.distributions.Dirichlet, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(concentration):
        batch_shape = concentration[:-1]
        event_shape = concentration[-1:]
        return batch_shape, event_shape

    def conjugate_update(self, other):
        """
        EXPERIMENTAL.
        """
        assert isinstance(other, Dirichlet)
        concentration = self.concentration + other.concentration - 1
        updated = Dirichlet(concentration)

        def _log_normalizer(d):
            c = d.concentration
            return c.sum(-1).lgamma() - c.lgamma().sum(-1)

        log_normalizer = (
            _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated)
        )
        return updated, log_normalizer


class Gamma(torch.distributions.Gamma, TorchDistributionMixin):
    def conjugate_update(self, other):
        """
        EXPERIMENTAL.
        """
        assert isinstance(other, Gamma)
        concentration = self.concentration + other.concentration - 1
        rate = self.rate + other.rate
        updated = Gamma(concentration, rate)

        def _log_normalizer(d):
            c = d.concentration
            return d.rate.log() * c - c.lgamma()

        log_normalizer = (
            _log_normalizer(self) + _log_normalizer(other) - _log_normalizer(updated)
        )
        return updated, log_normalizer


class Geometric(torch.distributions.Geometric, TorchDistributionMixin):
    # TODO: move upstream
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
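        # With logits = log(p / (1 - p)): log(1 - p) = -softplus(logits) and
        # log(p) = logits - softplus(logits), hence
        #   log_prob(k) = k * log(1 - p) + log(p)
        #               = (-k - 1) * softplus(logits) + logits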
        return (-value - 1) * torch.nn.functional.softplus(self.logits) + self.logits


class LogNormal(torch.distributions.LogNormal, TorchDistributionMixin):
    def __init__(self, loc, scale, validate_args=None):
        base_dist = Normal(loc, scale)
        # This differs from torch.distributions.LogNormal only in that base_dist is
        # a pyro.distributions.Normal rather than a torch.distributions.Normal.
        super(torch.distributions.LogNormal, self).__init__(
            base_dist,
            torch.distributions.transforms.ExpTransform(),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LogNormal, _instance)
        return super(torch.distributions.LogNormal, self).expand(
            batch_shape, _instance=new
        )


class LowRankMultivariateNormal(
    torch.distributions.LowRankMultivariateNormal, TorchDistributionMixin
):
    @staticmethod
    def infer_shapes(loc, cov_factor, cov_diag):
        event_shape = loc[-1:]
        batch_shape = broadcast_shape(loc[:-1], cov_factor[:-2], cov_diag[:-1])
        return batch_shape, event_shape


class MultivariateNormal(
    torch.distributions.MultivariateNormal, TorchDistributionMixin
):
    @staticmethod
    def infer_shapes(
        loc, covariance_matrix=None, precision_matrix=None, scale_tril=None
    ):
        batch_shape, event_shape = loc[:-1], loc[-1:]
        for matrix in [covariance_matrix, precision_matrix, scale_tril]:
            if matrix is not None:
                batch_shape = broadcast_shape(batch_shape, matrix[:-2])
        return batch_shape, event_shape
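
# Note infer_shapes operates on shape tuples rather than tensors, e.g.
# (a hypothetical call):
#
#     MultivariateNormal.infer_shapes(loc=(5, 3), scale_tril=(3, 3))
#     # -> batch_shape (5,), event_shape (3,)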


class Multinomial(torch.distributions.Multinomial, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(total_count=None, probs=None, logits=None):
        tensor = probs if logits is None else logits
        batch_shape, event_shape = tensor[:-1], tensor[-1:]
        if isinstance(total_count, tuple):
            batch_shape = broadcast_shape(batch_shape, total_count)
        return batch_shape, event_shape


class Normal(torch.distributions.Normal, TorchDistributionMixin):
    pass


class OneHotCategorical(torch.distributions.OneHotCategorical, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(probs=None, logits=None):
        tensor = probs if logits is None else logits
        event_shape = tensor[-1:]
        batch_shape = tensor[:-1]
        return batch_shape, event_shape


class Poisson(torch.distributions.Poisson, TorchDistributionMixin):
    def __init__(self, rate, *, is_sparse=False, validate_args=None):
        self.is_sparse = is_sparse
        super().__init__(rate, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Poisson, _instance)
        new = super().expand(batch_shape, _instance=new)
        new.is_sparse = self.is_sparse
        return new

    def log_prob(self, value):
        if not self.is_sparse:
            return super().log_prob(value)
        if self._validate_args:
            self._validate_sample(value)
        rate, value, nonzero = torch.broadcast_tensors(self.rate, value, value > 0)
        sparse_rate = rate[nonzero]
        sparse_value = value[nonzero]
        return (
            torch.zeros_like(rate).masked_scatter(
                nonzero,
                (sparse_rate.log() * sparse_value) - (sparse_value + 1).lgamma(),
            )
            - rate
        )
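
# The sparse path evaluates log() and lgamma() only at the nonzero counts,
# which pays off when values are mostly zeros, e.g. (a hypothetical sketch):
#
#     counts = torch.zeros(10000)
#     counts[::100] = 1.0
#     Poisson(torch.tensor(0.01), is_sparse=True).log_prob(counts)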


class Independent(torch.distributions.Independent, TorchDistributionMixin):
    @staticmethod
    def infer_shapes(**kwargs):
        raise NotImplementedError

    @property
    def _validate_args(self):
        return self.base_dist._validate_args

    @_validate_args.setter
    def _validate_args(self, value):
        self.base_dist._validate_args = value

    def conjugate_update(self, other):
        """
        EXPERIMENTAL.
        """
        n = self.reinterpreted_batch_ndims
        updated, log_normalizer = self.base_dist.conjugate_update(other.to_event(-n))
        updated = updated.to_event(n)
        log_normalizer = sum_rightmost(log_normalizer, n)
        return updated, log_normalizer


class Uniform(torch.distributions.Uniform, TorchDistributionMixin):
    def __init__(self, low, high, validate_args=None):
        self._unbroadcasted_low = low
        self._unbroadcasted_high = high
        super().__init__(low, high, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Uniform, _instance)
        new = super().expand(batch_shape, _instance=new)
        new._unbroadcasted_low = self._unbroadcasted_low
        new._unbroadcasted_high = self._unbroadcasted_high
        return new

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return constraints.interval(self._unbroadcasted_low, self._unbroadcasted_high)


def _cat_docstrings(*docstrings):
    result = "\n".join(textwrap.dedent(s.lstrip("\n")) for s in docstrings)
    result = re.sub("\n\n+", "\n\n", result)
    # Drop torch-specific xdoctest example lines.
    result = "".join(
        line for line in result.splitlines(keepends=True) if "xdoctest" not in line
    )
    return result


# Statically list the exported distributions to help mypy.
__all__ = [  # noqa: F822
    "Bernoulli",
    "Beta",
    "Binomial",
    "Categorical",
    "Cauchy",
    "Chi2",
    "ContinuousBernoulli",
    "Dirichlet",
    "ExponentialFamily",
    "Exponential",
    "FisherSnedecor",
    "Gamma",
    "Geometric",
    "Gumbel",
    "HalfCauchy",
    "HalfNormal",
    "Independent",
    "Kumaraswamy",
    "Laplace",
    "LKJCholesky",
    "LogNormal",
    "LogisticNormal",
    "LowRankMultivariateNormal",
    "MixtureSameFamily",
    "Multinomial",
    "MultivariateNormal",
    "NegativeBinomial",
    "Normal",
    "OneHotCategorical",
    "OneHotCategoricalStraightThrough",
    "Pareto",
    "Poisson",
    "RelaxedBernoulli",
    "RelaxedOneHotCategorical",
    "StudentT",
    "TransformedDistribution",
    "Uniform",
    "VonMises",
    "Weibull",
    "Wishart",
]

# Programmatically load all distributions from PyTorch,
# updating __all__ to include any new distributions.
for _name, _Dist in torch.distributions.__dict__.items():
    if not isinstance(_Dist, type):
        continue
    if not issubclass(_Dist, torch.distributions.Distribution):
        continue
    if _Dist is torch.distributions.Distribution:
        continue

    try:
        _PyroDist = locals()[_name]
    except KeyError:
        _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {})
        _PyroDist.__module__ = __name__
        locals()[_name] = _PyroDist

    _PyroDist.__doc__ = """
    Wraps :class:`{}.{}` with
    :class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.

    """.format(
        _Dist.__module__, _Dist.__name__
    )
    _PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__)
    __all__.append(_name)
__all__ = sorted(set(__all__))
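
# After this loop, e.g. pyro.distributions.Laplace subclasses both
# torch.distributions.Laplace and TorchDistributionMixin, so it gains
# Pyro-specific helpers such as .to_event() and .mask(). A hypothetical
# usage sketch:
#
#     d = Laplace(torch.zeros(3), torch.ones(3)).to_event(1)
#     assert d.event_shape == (3,)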


# Create sphinx documentation.
__doc__ = "\n\n".join(
    [
        """
    {0}
    ----------------------------------------------------------------
    .. autoclass:: pyro.distributions.{0}
    """.format(
            _name
        )
        for _name in sorted(__all__)
        # Work around sphinx autodoc error in case two InverseGamma's are defined:
        # "duplicate object description of pyro.distributions.InverseGamma"
        if _name != "InverseGamma"
    ]
)
